<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
    <title>WebGPU Compute Shaders - Histogram, slow</title>
    <style>
      @import url(resources/webgpu-lesson.css);
      /* Applies to the histogram canvas below and to the image canvas the script appends. */
      canvas {
        display: block;
        max-width: 256px;
        border: 1px solid #888;
        background-color: #333;
      }
    </style>
  </head>
  <body>
    <!-- The script renders the histogram into this (first) canvas. -->
    <canvas width="256" height="256" role="img" aria-label="Histogram of the image's red, green, blue and luminance values"></canvas>
  </body>
  <script type="module">
import TimingHelper from './resources/js/timing-helper.js';
// see https://webgpufundamentals.org/webgpu/lessons/webgpu-utils.html#webgpu-utils
import {
  loadImageBitmap,
  createTextureFromSource,
} from '../3rdparty/webgpu-utils-1.x.module.js';

/**
 * Runs the "slow" histogram demo end to end:
 *
 *  1. requests a WebGPU device (with 'timestamp-query' when the adapter has it),
 *  2. computes a per-channel (r, g, b, luminance) histogram of an image using a
 *     single compute invocation that walks every texel,
 *  3. computes the per-channel min/max bin counts in a second compute pass,
 *  4. draws the histogram into the page's first canvas with a fullscreen triangle,
 *  5. logs the duration reported by the timing helper for the first compute pass
 *     and shows the source image in a new canvas.
 *
 * Calls fail() and returns early when no WebGPU device can be obtained.
 */
async function main() {
  const adapter = await navigator.gpu?.requestAdapter();
  const canTimestamp = adapter?.features.has('timestamp-query');
  const device = await adapter?.requestDevice({
    requiredFeatures: [
      ...(canTimestamp ? ['timestamp-query'] : []),
    ],
  });
  if (!device) {
    fail('need a browser that supports WebGPU');
    return;
  }

  const timingHelper = new TimingHelper(device);

  // Number of histogram bins. The shaders read the bin count from the bound
  // buffer with arrayLength(), so this is the only place it is spelled out.
  const kNumBins = 256;

  // Get a WebGPU context from the canvas and configure it
  const canvas = document.querySelector('canvas');
  const context = canvas.getContext('webgpu');
  const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
  context.configure({
    device,
    format: presentationFormat,
  });

  // Pass 1: one invocation visits every texel of mip level 0 and bumps the
  // matching bin for each of r, g, b and (in component 3) luminance.
  const module = device.createShaderModule({
    label: 'histogram shader',
    code: /* wgsl */ `
      @group(0) @binding(0) var<storage, read_write> histogram: array<vec4u>;
      @group(0) @binding(1) var ourTexture: texture_2d<f32>;

      // from: https://www.w3.org/WAI/GL/wiki/Relative_luminance
      const kSRGBLuminanceFactors = vec3f(0.2126, 0.7152, 0.0722);
      fn srgbLuminance(color: vec3f) -> f32 {
        return saturate(dot(color, kSRGBLuminanceFactors));
      }

      @compute @workgroup_size(1, 1, 1) fn cs() {
        let size = textureDimensions(ourTexture, 0);
        let numBins = f32(arrayLength(&histogram));
        let lastBinIndex = u32(numBins - 1);
        for (var y = 0u; y < size.y; y++) {
          for (var x = 0u; x < size.x; x++) {
            let position = vec2u(x, y);
            let color = textureLoad(ourTexture, position, 0);
            let luminance = srgbLuminance(color.rgb);
            for (var ch = 0; ch < 4; ch++) {
              // component 3 of each bin counts luminance instead of alpha
              let v = select(color[ch], luminance, ch == 3);
              // a value of exactly 1.0 would index one past the end, so clamp
              let bin = min(u32(v * numBins), lastBinIndex);
              histogram[bin][ch] += 1;
            }
          }
        }
      }
    `,
  });

  const pipeline = device.createComputePipeline({
    label: 'histogram',
    layout: 'auto',
    compute: {
      module,
    },
  });

  // Pass 2: reduce the bins to per-channel minimum and maximum counts.
  const rangeModule = device.createShaderModule({
    label: 'compute histogram range shader',
    code: /* wgsl */ `
      struct HistogramInfo {
        min: vec4u,
        max: vec4u,
      };

      @group(0) @binding(0) var<storage, read> histogram: array<vec4u>;
      @group(0) @binding(1) var<storage, read_write> histogramInfo: HistogramInfo;

      @compute @workgroup_size(1) fn cs() {
        // take the bin count from the bound buffer instead of hard-coding it
        let numBins = arrayLength(&histogram);
        var minV = histogram[0];
        var maxV = histogram[0];
        for (var i = 1u; i < numBins; i++) {
          let v = histogram[i];
          minV = min(minV, v);
          maxV = max(maxV, v);
        }
        histogramInfo.min = minV;
        histogramInfo.max = maxV;
      }
    `,
  });
  const rangePipeline = device.createComputePipeline({
    label: 'histogram range',
    layout: 'auto',
    compute: {
      module: rangeModule,
    },
  });

  // Pass 3: draw the histogram with a single triangle that covers the canvas.
  const renderModule = device.createShaderModule({
    label: 'draw histogram shader',
    code: /* wgsl */ `
      struct VSOutput {
        @builtin(position) position: vec4f,
        @location(0) texcoord: vec2f,
      };

      struct HistogramInfo {
        min: vec4u,
        max: vec4u,
      };

      // the fragment shader only reads the bins, so bind them read-only
      @group(0) @binding(0) var<storage, read> histogram: array<vec4u>;
      @group(0) @binding(1) var<storage, read> histogramInfo: HistogramInfo;

      @vertex fn vs(
        @builtin(vertex_index) vertexIndex : u32
      ) -> VSOutput {
        // one oversized triangle that covers the whole clip space square
        let pos = array(
           vec2f(-1, -1),
           vec2f(-1,  3),
           vec2f( 3, -1),
        );

        var vsOutput: VSOutput;
        let xy = pos[vertexIndex];
        vsOutput.position = vec4f(xy, 0.0, 1.0);
        // map clip space to 0..1 with y flipped, so texcoord.y is 0 at the top
        vsOutput.texcoord = xy * vec2f(0.5, -0.5) + vec2f(0.5);
        return vsOutput;
      }

      @fragment fn fs(vsOutput: VSOutput) -> @location(0) vec4f {
        // convert texcoord.x to histogram index
        let numBins = arrayLength(&histogram);
        let ndx = min(u32(vsOutput.texcoord.x * f32(numBins)), numBins - 1u);

        // convert texcoord.y to histogram range
        let height = vec4u(vsOutput.texcoord.y * vec4f(histogramInfo.max));

        // each channel is lit where the scaled row is at or past that bin's count
        let count = histogram[ndx];
        let color = select(vec4f(0), vec4f(1), height >= count);
        return color;
      }
    `,
  });

  const renderPipeline = device.createRenderPipeline({
    label: 'histogram render pipeline',
    layout: 'auto',
    vertex: {
      module: renderModule,
    },
    fragment: {
      module: renderModule,
      targets: [{ format: presentationFormat }],
    },
  });

  const imgBitmap = await loadImageBitmap('resources/images/pexels-francesco-ungaro-96938-mid.jpg'); /* webgpufundamentals: url */
  const texture = createTextureFromSource(device, imgBitmap);

  // The histogram shader only ever increments bins, so it relies on this buffer
  // starting out zeroed (WebGPU zero-initializes newly created buffers).
  const histogramBuffer = device.createBuffer({
    size: kNumBins * 4 * 4, // kNumBins entries * 4 (rgba) * 4 bytes per (u32)
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  });

  // Mappable copy of the bins. It is filled below, but the readback itself is
  // currently commented out at the end of this function.
  const resultBuffer = device.createBuffer({
    size: histogramBuffer.size,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });

  const bindGroup = device.createBindGroup({
    label: 'histogram bindGroup',
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: histogramBuffer }},
      { binding: 1, resource: texture.createView() },
    ],
  });

  const histogramInfoBuffer = device.createBuffer({
    size: 32, // vec4u * 2 (HistogramInfo.min and HistogramInfo.max)
    usage: GPUBufferUsage.STORAGE,
  });

  const rangeBindGroup = device.createBindGroup({
    label: 'range bindGroup',
    layout: rangePipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: histogramBuffer }},
      { binding: 1, resource: { buffer: histogramInfoBuffer } },
    ],
  });

  const renderBindGroup = device.createBindGroup({
    label: 'render bindGroup',
    layout: renderPipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: histogramBuffer }},
      { binding: 1, resource: { buffer: histogramInfoBuffer } },
    ],
  });

  const renderPassDescriptor = {
    label: 'our basic canvas renderPass',
    colorAttachments: [
      {
        // view: <- to be filled out when we render
        clearValue: [0, 0, 0, 1],
        loadOp: 'clear',
        storeOp: 'store',
      },
    ],
  };

  // Get the current texture from the canvas context and
  // set it as the texture to render to.
  renderPassDescriptor.colorAttachments[0].view =
      context.getCurrentTexture().createView();

  const encoder = device.createCommandEncoder({
    label: 'histogram encoder',
  });
  {
    // Only this pass goes through the timing helper, so the logged duration
    // covers the histogram computation and not the range or render passes.
    const pass = timingHelper.beginComputePass(encoder);
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(1);
    pass.end();
  }

  {
    const pass = encoder.beginComputePass();
    pass.setPipeline(rangePipeline);
    pass.setBindGroup(0, rangeBindGroup);
    pass.dispatchWorkgroups(1);
    pass.end();
  }

  {
    const pass = encoder.beginRenderPass(renderPassDescriptor);
    pass.setPipeline(renderPipeline);
    pass.setBindGroup(0, renderBindGroup);
    pass.draw(3);
    pass.end();
  }

  encoder.copyBufferToBuffer(histogramBuffer, 0, resultBuffer, 0, resultBuffer.size);

  const commandBuffer = encoder.finish();
  device.queue.submit([commandBuffer]);

  timingHelper.getResult().then(duration => {
    console.log(`duration: ${duration}ns`);
  });

  // CPU readback of the bins, left disabled:
  //await resultBuffer.mapAsync(GPUMapMode.READ);
  //const result = new Uint32Array(resultBuffer.getMappedRange());

  // Must run after the texture was created from the bitmap: showImageBitmap
  // transfers the bitmap into a canvas.
  showImageBitmap(imgBitmap);
}

/**
 * Appends a new canvas to the document body that displays the given bitmap.
 *
 * The canvas is sized to the bitmap and the bitmap is handed to a
 * 'bitmaprenderer' context via transferFromImageBitmap before the canvas is
 * added to the page.
 *
 * @param {ImageBitmap} imageBitmap bitmap to display; its width and height
 *     size the canvas.
 */
function showImageBitmap(imageBitmap) {
  const { width, height } = imageBitmap;
  const target = Object.assign(document.createElement('canvas'), { width, height });

  target.getContext('bitmaprenderer').transferFromImageBitmap(imageBitmap);
  document.body.appendChild(target);
}

/**
 * Reports a fatal problem to the user with a blocking alert dialog.
 *
 * @param {string} msg text shown in the dialog.
 */
function fail(msg) {
  // eslint-disable-next-line no-alert
  globalThis.alert(msg);
}

// Start the demo. main() is async; the promise it returns is neither awaited
// nor given a rejection handler here.
main();
  </script>
</html>
